Add DCP compatibility for FSDP2-TP sharding in TransformerEngine. #2713
cspades wants to merge 5 commits into NVIDIA:main from cspades:cye/fsdp2-tp-dcp
Conversation
force-pushed from 50da1dc to 925d022
Greptile Summary
This PR adds DCP (Distributed Checkpoint) compatibility for FSDP2+TP strided sharding across all `TransformerEngineBaseModule`(s). Key issues found:
Confidence Score: 2/5
Last reviewed commit: fcdd5bd
```diff
 if args.sharding_dims:
-    assert len(args.sharding_dims) <= 2
+    assert len(args.sharding_dims) <= 3
+if len(args.sharding_dims) >= 3:
+    # Set the TP size in args.
+    args.tp_size = args.sharding_dims[2]
+else:
+    args.tp_size = 1
 return args
```
args.sharding_dims not guarded against None
At line 153, `len(args.sharding_dims)` is called unconditionally, but `args.sharding_dims` can be `None` when the `--sharding-dims` flag is omitted (since the argument is not marked `required=True` and uses `nargs="+"`). This will raise `TypeError: object of type 'NoneType' has no len()`.
The `if len(args.sharding_dims) >= 3:` block should be nested inside the existing `if args.sharding_dims:` guard:
Suggested change:

```python
if args.sharding_dims:
    assert len(args.sharding_dims) <= 3
    if len(args.sharding_dims) >= 3:
        # Set the TP size in args.
        args.tp_size = args.sharding_dims[2]
    else:
        args.tp_size = 1
else:
    args.tp_size = 1
```
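The guarded logic can be factored into a small standalone helper for illustration; the function name `resolve_tp_size` is hypothetical and not part of the PR:

```python
def resolve_tp_size(sharding_dims):
    """Sketch of the guarded logic: tolerate sharding_dims=None and
    derive the TP size only when a third sharding dimension is given."""
    if sharding_dims:
        # At most three sharding dimensions; the third is the TP size.
        assert len(sharding_dims) <= 3
        if len(sharding_dims) >= 3:
            return sharding_dims[2]
    return 1

print(resolve_tp_size(None))        # 1
print(resolve_tp_size([8, 4]))      # 1
print(resolve_tp_size([8, 4, 2]))   # 2
```

With the guard in place, omitting `--sharding-dims` no longer raises a `TypeError`.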
```python
self.fc1_bias = _convert_param_to_dtensor_param(
    self.fc1_bias, tp_mesh, placements=(Shard(dim=0),)
)
# FC2 Weight -> Row-Parallel -> Shard(dim=1)
self.fc2_weight = _convert_param_to_dtensor_param(
    self.fc2_weight, tp_mesh, placements=(Shard(dim=1),)
)
# LN & FC2 Bias -> Replicate()
self.fc2_bias = _convert_param_to_dtensor_param(
    self.fc2_bias, tp_mesh, placements=(Replicate(),)
)
```
Bias converted unconditionally when use_bias=False
When `use_bias=False`, `self.fc1_bias` and `self.fc2_bias` are initialized as plain `torch.Tensor` objects (not `nn.Parameter`, see lines 1940 and 1958):

```python
else:
    self.fc1_bias = torch.Tensor().to(dtype=params_dtype, device=device)
```

Calling `_convert_param_to_dtensor_param` on them returns `nn.Parameter(DTensor.from_local(...))`. When this is then assigned back via `self.fc1_bias = new_param`, PyTorch's `Module.__setattr__` will detect the `nn.Parameter` type and register the bias as a named module parameter, even though biases are disabled. This would pollute `model.named_parameters()`, the optimizer parameter list, and checkpoint state.
The fix is to guard these two conversions behind `if self.use_bias:`, following the same pattern already used for `layer_norm_bias` at line 2091.
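The registration hazard can be reproduced without PyTorch; the sketch below mimics `nn.Module.__setattr__`'s parameter registration with stand-in classes (all names hypothetical):

```python
class Param:
    """Stand-in for nn.Parameter (illustration only)."""

class MiniModule:
    """Toy module mimicking how nn.Module.__setattr__ registers any
    Param-typed attribute into the module's parameter dict."""
    def __init__(self):
        object.__setattr__(self, "_parameters", {})

    def __setattr__(self, name, value):
        if isinstance(value, Param):
            # nn.Module does this for nn.Parameter assignments.
            self._parameters[name] = value
        else:
            object.__setattr__(self, name, value)

m = MiniModule()
m.fc1_bias = object()        # plain tensor-like object: not registered
print(list(m._parameters))   # []
m.fc1_bias = Param()         # converted to a Param: silently registered
print(list(m._parameters))   # ['fc1_bias']
```

This is why converting a disabled bias into an `nn.Parameter` leaks it into `named_parameters()` even though it was never meant to be trainable.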
```python
weight_mesh : Optional[DeviceMesh]
    Not used for DotProductAttention as there are no quantized weights.
"""
warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
```
Spurious warning when weight_mesh is None
`warnings.warn(...)` is emitted unconditionally every time `set_device_mesh` is called, even when `weight_mesh=None`. The calling code invokes this method whenever `tp_mesh is not None or weight_mesh is not None`, so a normal call with only `tp_mesh` provided will generate a misleading warning like "weight_mesh not necessary for DotProductAttention: None".
The warning should only fire when the caller explicitly passes a non-None `weight_mesh`. The same spurious warning exists in `transformer_engine/pytorch/module/layernorm.py` (line 171) and `transformer_engine/pytorch/module/rmsnorm.py` (line 174).
```diff
-warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
+if weight_mesh is not None:
+    warnings.warn(f"weight_mesh not necessary for {self.__class__.__name__}: {weight_mesh}")
```
Signed-off-by: Cory Ye <cye@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Cory Ye <cye@nvidia.com>
force-pushed from 4ec2947 to dbb9d14
Diff hunk `@@ -30,6 +38,61 @@`:

```python
LOCAL_RANK = None


@dataclass
class AppState(Stateful):
    """AppState for FSDP2 checkpoint via Torch DCP.

    Adapted from https://docs.pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
    """

    model: torch.nn.Module
    optimizer: torch.optim.Optimizer

    def state_dict(self):
        """
        Get the state dict for the model, optimizer, scheduler, and step.
        This factory both retrieves the model state dictionary when saving
        checkpoints and initializes a destination for the state read from
        DCP checkpoint files when loading checkpoints.
        """
        model_state_dict, optimizer_state_dict = get_state_dict(self.model, self.optimizer)
        for fqn in list(model_state_dict.keys()):
            # Get the model parameter.
            model_param = model_state_dict[fqn]
            if isinstance(model_param, DTensor):
                model_param = model_param.to_local()
            if model_param.numel() == 0 and fqn in optimizer_state_dict["state"]:
                # Empty model parameter. Clear the associated optimizer state
                # when initializing the optimizer state upon DCP load, because
                # empty optimizer state DTensors are not checkpointed with DCP,
                # yet get_state_dict / _init_optim_state produce empty Tensors.
                # TransformerEngine uses empty Tensors for dummy Parameters.
                optimizer_state_dict["state"][fqn] = {}
            if fqn.endswith("._extra_state"):
                # Evict `_extra_state` quantization data from model checkpoint.
                model_state_dict.pop(fqn)
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
        }

    def load_state_dict(self, state_dict: dict):
        """
        Load the state dict for the model, optimizer, scheduler, and step.
        Given the checkpoint-loaded state_dict, set the state of the model,
        optimizer, scheduler, step, and epoch to the values in state_dict.
        """
        set_state_dict(
            self.model,
            self.optimizer,
            model_state_dict=state_dict["model"],
            optim_state_dict=state_dict["optim"],
            # Non-strict checkpoint loading ignores empty optimizer states,
            # skips loading non-FP8 checkpoint weights (e.g. _extra_state).
            options=StateDictOptions(strict=False),
        )


def dist_print(msg):
    if LOCAL_RANK == 0:
```
DCP checkpoint functionality is not exercised in the test
The `AppState` class (lines 42–93) and the DCP checkpoint operations (save, load, `get_state_dict`, `set_state_dict`) are imported and fully implemented, but the training loop in `_train()` (lines 480–490) does not call any checkpoint save/load operations. The function ends at line 497 with `dist.destroy_process_group()` and no checkpoint round-trip.
Since the PR title is "Add DCP compatibility for FSDP2-TP sharding," the checkpoint functionality is the headline feature. Without an actual save/load call in the test, neither the `AppState.state_dict()` eviction logic nor the `set_state_dict(strict=False)` reload path is validated.
Recommendation: add a checkpoint save/load round-trip after the training loop (before `dist.destroy_process_group()`) to exercise the DCP functionality, or explicitly note in the test docstring that DCP round-trip testing is deferred to integration tests.
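The eviction logic in `AppState.state_dict()` can at least be exercised standalone; the sketch below replays it on plain dicts with a stand-in parameter type (all names here are hypothetical, not TransformerEngine APIs):

```python
from dataclasses import dataclass

@dataclass
class FakeParam:
    """Stand-in for a (possibly empty) model parameter."""
    n: int
    def numel(self) -> int:
        return self.n

def clean_state_dicts(model_sd: dict, optim_sd: dict) -> dict:
    """Mirror of AppState.state_dict() post-processing: clear optimizer
    state for empty dummy parameters and evict `_extra_state` entries."""
    for fqn in list(model_sd.keys()):
        param = model_sd[fqn]
        if param.numel() == 0 and fqn in optim_sd["state"]:
            # Empty dummy parameter: drop its (unsaveable) optimizer state.
            optim_sd["state"][fqn] = {}
        if fqn.endswith("._extra_state"):
            # Quantization metadata is not checkpointed via DCP.
            model_sd.pop(fqn)
    return {"model": model_sd, "optim": optim_sd}

model_sd = {
    "fc1.weight": FakeParam(8),
    "fc1.bias": FakeParam(0),          # dummy empty parameter
    "fc1._extra_state": FakeParam(1),  # FP8 quantization metadata
}
optim_sd = {"state": {
    "fc1.weight": {"exp_avg": FakeParam(8)},
    "fc1.bias": {"exp_avg": FakeParam(0)},
}}
out = clean_state_dicts(model_sd, optim_sd)
print(sorted(out["model"]))               # ['fc1.bias', 'fc1.weight']
print(out["optim"]["state"]["fc1.bias"])  # {}
```

A full DCP round-trip test would still be needed to validate the `set_state_dict(strict=False)` reload path on real `DTensor` shards.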
Working on it! + GroupedLinear test case.
Signed-off-by: Cory Ye <cye@nvidia.com>
```python
CKPT_DIR = (
    Path(SHARED_TMP_DIR)
    / "run_fsdp2_model"
    / f"dcp-{'_'.join(str(x) for x in args.sharding_dims)}-{args.layer_type}-{args.recipe}-fp8_init_{args.fp8_init}"
)
```
`args.sharding_dims` can be `None` when `--sharding-dims` is not passed (the argument uses `nargs="+"` without `required=True`). The f-string on this line iterates over it directly, which will raise `TypeError: 'NoneType' object is not iterable` in that case.
```diff
-    / f"dcp-{'_'.join(str(x) for x in args.sharding_dims)}-{args.layer_type}-{args.recipe}-fp8_init_{args.fp8_init}"
+    / f"dcp-{'_'.join(str(x) for x in (args.sharding_dims or []))}-{args.layer_type}-{args.recipe}-fp8_init_{args.fp8_init}"
```
```python
grouped_param = _convert_param_to_dtensor_param(
    grouped_param,
    device_mesh=dtensor_member_param.device_mesh,
    placements=dtensor_member_param.placements,
    # DTensor / DCP will view this as a TP-sharded 3-D Tensor.
    shape=(self.num_gemms, self.out_features, self.in_features),
    # Default Stride: (out*in, in, 1)
    stride=None,
)
```
`dtensor_member_param.placements` was assigned in `set_device_mesh` relative to the 2-D weight shape `(out_features, in_features)`:

- Column-parallel → `Shard(dim=0)` (shard over `out_features`)
- Row-parallel → `Shard(dim=1)` (shard over `in_features`)

But the global shape here is the 3-D tensor `(num_gemms, out_features, in_features)`. Reusing the same placements verbatim means:

- `Shard(dim=0)` now refers to the `num_gemms` axis, which is wrong: each TP rank holds all gemms.
- `Shard(dim=1)` would refer to `out_features` when the row-parallel split is actually on `in_features` (dim=2).

As a result, when DCP reconstructs the full checkpoint from the local shards it will use the wrong axis, producing silently corrupted weight tensors.
The shard dimensions need to be incremented by 1 to account for the prepended `num_gemms` dimension:

```python
from torch.distributed.tensor import Shard as _Shard, Replicate as _Replicate

adjusted_placements = tuple(
    _Shard(p.dim + 1) if isinstance(p, _Shard) else p
    for p in dtensor_member_param.placements
)
grouped_param = _convert_param_to_dtensor_param(
    grouped_param,
    device_mesh=dtensor_member_param.device_mesh,
    placements=adjusted_placements,
    shape=(self.num_gemms, self.out_features, self.in_features),
    stride=None,
)
```

Okay, this is an impressive catch. I wrote this code too quickly. Will fix!
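The dim-shift rule can be checked in isolation; the sketch below uses stand-in placement classes rather than `torch.distributed.tensor` so the rule is visible without a distributed setup (the stand-in classes and helper name are hypothetical):

```python
from dataclasses import dataclass
from typing import Tuple, Union

@dataclass(frozen=True)
class Shard:
    """Minimal stand-in for torch.distributed.tensor.Shard."""
    dim: int

class Replicate:
    """Minimal stand-in for torch.distributed.tensor.Replicate."""

Placement = Union[Shard, Replicate]

def shift_for_prepended_dim(placements: Tuple[Placement, ...]) -> Tuple[Placement, ...]:
    """Shift every Shard dim by +1 when a leading num_gemms axis is
    prepended to the 2-D (out_features, in_features) weight shape;
    Replicate placements are unaffected."""
    return tuple(Shard(p.dim + 1) if isinstance(p, Shard) else p for p in placements)

# A row-parallel 2-D weight sharded on in_features (dim=1) must be
# sharded on dim=2 once viewed as (num_gemms, out_features, in_features).
print(shift_for_prepended_dim((Shard(1),)))  # (Shard(dim=2),)
```

Without this shift, DCP would gather shards along the wrong axis when reassembling the 3-D grouped weight.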
Summary
- DCP checkpoint compatibility for `DTensor` parameters with FP8, across all `TransformerEngineBaseModule`(s).

Details

- Previously, `"shard"` was the presumed weight-sharding sub-mesh in the `DTensor.device_mesh`. Now, users can precisely specify their own custom weight-sharding `DeviceMesh` for per-tensor `amax_reduction_group` via the `set_device_mesh` API.

Testing

- Fails on both `main` and `cspades:cye/fsdp2-tp-dcp`, so we can assume it is not associated to my change: https://github.com/NVIDIA/Megatron-LM/actions/runs/22637904520/job/65636890955?pr=3661 (TransformerEngine `main`)
- `main` vs. `cspades:cye/fsdp2-tp-dcp` with Megatron-LM `main` on PyTorch 25.11

Type of change
Changes
Please list the changes introduced in this PR:
Checklist: